/* Copyright 2025 The OpenXLA Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "xla/backends/gpu/autotuner/cublas.h"

#include <memory>
#include <utility>
#include <vector>

#include "absl/log/log.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "xla/autotuning.pb.h"
#include "xla/backends/autotuner/codegen_backend.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/hlo/utils/hlo_query.h"
#include "xla/service/gpu/backend_configs.pb.h"
#include "xla/service/gpu/cublas_cudnn.h"
#include "xla/service/gpu/matmul_utils.h"
#include "xla/service/gpu/transforms/dot_algorithm_rewriter.h"
#include "xla/service/gpu/transforms/gemm_rewriter.h"
#include "xla/service/gpu/transforms/priority_fusion.h"
#include "xla/service/hlo_cost_analysis.h"
#include "xla/stream_executor/blas.h"
#include "xla/stream_executor/device_description.h"
#include "xla/stream_executor/device_memory.h"
#include "xla/stream_executor/device_memory_allocator.h"
#include "xla/stream_executor/gpu/gpu_blas_lt.h"
#include "xla/stream_executor/stream_executor_memory_allocator.h"
#include "xla/tsl/platform/errors.h"
#include "xla/tsl/platform/statusor.h"

namespace xla {
namespace gpu {

namespace se = ::stream_executor;

// Enumerates the cuBLAS GEMM algorithms that can be tried for `instr`.
//
// Returns an empty list when `instr` is not a legacy cuBLAS matmul. Otherwise
// each returned config is a `google::protobuf::Any` wrapping an
// `AutotuneResult::GemmKey` whose `algorithm` field is one of the algorithms
// reported by the BLAS library for the GEMM described by `instr`.
//
// Returns an error if a stream cannot be obtained, if the instruction's
// backend config cannot be read, if the GEMM config / matrix descriptors /
// compute type cannot be derived, or if the executor has no BLAS support.
absl::StatusOr<std::vector<std::unique_ptr<BackendConfig>>>
CublasBackend::GetSupportedConfigs(const HloInstruction& instr) {
  if (!IsLegacyCublasMatmul(instr)) {
    return std::vector<std::unique_ptr<BackendConfig>>();
  }

  // The allocator is kept alive for the whole function; the stream obtained
  // from it is handed to the BLAS algorithm query below.
  std::unique_ptr<se::DeviceMemoryAllocator> allocator =
      std::make_unique<se::StreamExecutorMemoryAllocator>(stream_executor());
  TF_ASSIGN_OR_RETURN(
      se::Stream * stream,
      allocator->GetStream(stream_executor()->device_ordinal()));

  // We use GemmConfig::For with GemmBackendConfig as a fallback because
  // matmul_utils.cc relies on backend config to determine gemm contracting
  // dimensions.
  //
  // Propagate a failure to read the backend config instead of dereferencing
  // an unchecked StatusOr.
  TF_ASSIGN_OR_RETURN(GpuBackendConfig gpu_config,
                      instr.backend_config<GpuBackendConfig>());
  const GemmBackendConfig& backend_config = gpu_config.gemm_backend_config();
  TF_ASSIGN_OR_RETURN(
      GemmConfig gemm_config,
      GemmConfig::For(
          &instr, backend_config,
          target_config().device_description.gpu_compute_capability()));

  // Builds a data-less matrix descriptor from a layout; only the strides,
  // element type and transpose flag are filled in.
  auto create_matrix_desc = [](const se::gpu::MatrixLayout& layout)
      -> absl::StatusOr<se::gpu::MatrixDescriptor> {
    TF_ASSIGN_OR_RETURN(se::blas::DataType type,
                        se::gpu::AsBlasDataType(layout.dtype));
    return se::gpu::MatrixDescriptor{
        /*data=*/se::DeviceMemoryBase(), layout.leading_dim_stride,
        layout.batch_stride, type,
        // BLAS is column-major by default.
        (layout.order == se::gpu::MatrixLayout::Order::kColumnMajor
             ? se::blas::Transpose::kNoTranspose
             : se::blas::Transpose::kTranspose)};
  };

  TF_ASSIGN_OR_RETURN(se::gpu::MatrixDescriptor lhs_desc,
                      create_matrix_desc(gemm_config.lhs_layout));
  TF_ASSIGN_OR_RETURN(se::gpu::MatrixDescriptor rhs_desc,
                      create_matrix_desc(gemm_config.rhs_layout));
  TF_ASSIGN_OR_RETURN(se::gpu::MatrixDescriptor output_desc_base,
                      create_matrix_desc(gemm_config.output_layout));

  // The output descriptor additionally carries the problem size (m, n, k),
  // the batch size and the compute type.
  se::gpu::OutputMatrixDescriptor out_desc(std::move(output_desc_base));
  out_desc.batch_size = gemm_config.output_layout.batch_size;
  out_desc.m = gemm_config.output_layout.num_rows;
  out_desc.n = gemm_config.output_layout.num_cols;
  out_desc.k = gemm_config.lhs_layout.num_cols;
  TF_ASSIGN_OR_RETURN(
      out_desc.compute_type,
      se::gpu::GetBlasComputationType(
          gemm_config.precision_algorithm, gemm_config.lhs_layout.dtype,
          gemm_config.output_layout.dtype, gemm_config.compute_precision));

  se::blas::BlasSupport* blas = stream_executor()->AsBlas();
  if (blas == nullptr) {
    return absl::InternalError("Failed to getBlas support.");
  }
  std::vector<se::blas::AlgorithmType> algorithms;

  // NOTE(review): any result reported by GetBlasGemmAlgorithms is not
  // inspected here; if the query yields nothing, `algorithms` stays empty and
  // an empty config list is returned.
  blas->GetBlasGemmAlgorithms(stream, lhs_desc, rhs_desc, &out_desc,
                              &gemm_config.alpha, &gemm_config.beta,
                              &algorithms);

  // Wrap every algorithm id into a GemmKey packed inside an Any.
  std::vector<std::unique_ptr<BackendConfig>> configs;
  configs.reserve(algorithms.size());
  for (se::blas::AlgorithmType algorithm : algorithms) {
    AutotuneResult::GemmKey gemm_key;
    gemm_key.set_algorithm(algorithm);
    auto any = std::make_unique<google::protobuf::Any>();
    any->PackFrom(gemm_key);
    configs.push_back(std::move(any));
  }
  return configs;
}

namespace {
// Builds the HloCostAnalysis options handed to the PriorityFusion pass:
// default options with `count_multiple_input_accesses` switched on.
HloCostAnalysis::Options PriorityFusionOptions() {
  HloCostAnalysis::Options cost_options;
  // The pointer size is deliberately left at its default. GpuCompiler sets
  // the real value, but HloCostAnalysis only consults it to size tuple types,
  // and the autotuned module is not expected to contain tuples, so there is
  // no need to pipe the real value through.
  cost_options.count_multiple_input_accesses = true;
  return cost_options;
}
}  // namespace

// Rewrites the dot in `hlo_module`'s entry computation so that it can be
// executed through the GEMM rewriter pipeline.
//
// The function takes ownership of `hlo_module`, runs DotAlgorithmRewriter,
// GemmRewriter and PriorityFusion over it twice (once with the FP8-only and
// once with the non-FP8-only GemmRewriter options), and returns the modified
// module.
//
// Returns InvalidArgumentError when the entry computation contains no dot
// instruction, and forwards any error reported by the passes.
absl::StatusOr<std::unique_ptr<HloModule>> RewriteToCublasCustomCall(
    std::unique_ptr<HloModule> hlo_module,
    const se::DeviceDescription& gpu_device_info) {
  HloInstruction* dot = hlo_query::GetFirstInstructionWithOpcode(
      *hlo_module->entry_computation(), HloOpcode::kDot);
  // The lookup yields no instruction when the entry computation has no dot;
  // report that as an error rather than dereferencing a null pointer.
  if (dot == nullptr) {
    return absl::InvalidArgumentError(
        "RewriteToCublasCustomCall: no dot instruction found in the entry "
        "computation.");
  }
  // Substitute algorithms that are not supported by cuBLAS so that the check
  // can still be performed, even though cuBLAS is not used in the end. This
  // assumes that the substituting algorithm produces results that are close
  // enough for the check in this file.
  if (dot->precision_config().algorithm() ==
      PrecisionConfig::ALG_DOT_TF32_TF32_F32_X3) {
    dot->mutable_precision_config()->set_algorithm(
        PrecisionConfig::ALG_DOT_F32_F32_F32);
  }

  // Run the rewrite pipeline once per GemmRewriter dtype mode; the passes are
  // constructed fresh for each mode.
  for (GemmRewriterOptions::DType dtype :
       {GemmRewriterOptions::DType::kFp8Only,
        GemmRewriterOptions::DType::kNonFp8Only}) {
    GemmRewriter gemm_rewriter(gpu_device_info.cuda_compute_capability(),
                               gpu_device_info.runtime_version(),
                               GemmRewriterOptions{dtype});
    DotAlgorithmRewriter dot_algorithm_rewriter;
    PriorityFusion fusion_pass(
        /*thread_pool=*/nullptr, gpu_device_info, PriorityFusionOptions());
    TF_RETURN_IF_ERROR(dot_algorithm_rewriter.Run(hlo_module.get()).status());
    TF_RETURN_IF_ERROR(gemm_rewriter.Run(hlo_module.get()).status());
    TF_RETURN_IF_ERROR(fusion_pass.Run(hlo_module.get()).status());
  }

  return hlo_module;
}

// Produces the config that selects `se::blas::kDefaultAlgorithm` for `instr`.
//
// The result is a `google::protobuf::Any` wrapping an
// `AutotuneResult::GemmKey`. Returns InvalidArgumentError when `instr` is not
// a legacy cuBLAS matmul.
absl::StatusOr<std::unique_ptr<BackendConfig>> CublasBackend::GetDefaultConfig(
    const HloInstruction& instr) {
  const bool is_supported = IsLegacyCublasMatmul(instr);
  if (!is_supported) {
    return absl::InvalidArgumentError(
        "CublasBackend does not support this instruction.");
  }

  // Describe the default algorithm as a GemmKey ...
  AutotuneResult::GemmKey default_key;
  default_key.set_algorithm(se::blas::kDefaultAlgorithm);

  // ... and hand it back packed inside an Any.
  auto packed_config = std::make_unique<google::protobuf::Any>();
  packed_config->PackFrom(default_key);
  return packed_config;
}

// Applies the algorithm stored in `config` to `instr`.
//
// `config` must unpack to an `AutotuneResult::GemmKey`; its algorithm is
// written to `selected_algorithm` of the instruction's GemmBackendConfig and
// the updated GpuBackendConfig is stored back on `instr`.
//
// Returns InvalidArgumentError if `config` cannot be unpacked, and forwards
// errors from reading or writing the instruction's backend config.
absl::Status CublasBackend::ApplyConfig(HloInstruction& instr,
                                        const BackendConfig& config) {
  AutotuneResult::GemmKey selected_key;
  const bool unpacked = config.UnpackTo(&selected_key);
  if (!unpacked) {
    return absl::InvalidArgumentError(
        "Failed to unpack CublasBackendConfig from Any.");
  }

  // Work on a copy of the current backend config, then write it back whole.
  TF_ASSIGN_OR_RETURN(GpuBackendConfig updated_config,
                      instr.backend_config<GpuBackendConfig>());
  updated_config.mutable_gemm_backend_config()->set_selected_algorithm(
      selected_key.algorithm());
  TF_RETURN_IF_ERROR(instr.set_backend_config(std::move(updated_config)));
  return absl::OkStatus();
}

}  // namespace gpu
}  // namespace xla
